#!/usr/bin/env python3.10

import importlib.machinery
import importlib.util
import inspect as i
import json
import logging
import re
import tempfile as tf
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Union, get_args, get_origin

import datamodel_code_generator as d
import requests
from pydantic import BaseModel, Field, create_model
from pydantic.fields import ModelField

# Configure logging at INFO level so the progress messages logged by this
# script are emitted.
logging.basicConfig(level=logging.INFO)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Regexes used by camel_to_snake to locate word boundaries.
# pat1: any character followed by a capitalised word (e.g. "lCase").
# pat2: a lowercase letter or digit directly followed by an uppercase letter.
pat1 = re.compile("(.)([A-Z][a-z]+)")
pat2 = re.compile("([a-z0-9])([A-Z])")


def camel_to_snake(name):
    """Convert a CamelCase / mixedCase identifier to snake_case.

    An underscore is inserted at every boundary matched by ``pat1`` and then
    by ``pat2``; the result is lower-cased.
    """
    separator = r"\1_\2"
    with_word_breaks = pat1.sub(separator, name)
    return pat2.sub(separator, with_word_breaks).lower()


class FieldSummary(BaseModel):
    """Summary of one model-typed field, as built by ``extract_summary``."""

    # Lua-facing name of the field's type (result of DAPTypesGenerator.type_name).
    type_name: str
    # The pydantic model class the field holds.
    type: type[BaseModel]


class ModelSummary(BaseModel):
    """Summary of a model class, as built by ``extract_summary``."""

    # The summarised pydantic model class.
    type: type[BaseModel]
    # Lua-facing name of the model (result of DAPTypesGenerator.type_name).
    type_name: str
    # Summaries of the fields whose type is itself a model, keyed by field name.
    fields: dict[str, FieldSummary]
    # Documentation text for the model, or None when none was found.
    description: str | None


class DAPTypesGenerator:
    def __init__(self) -> None:
        """Create the generator and immediately load the generated models.

        Construction calls ``create_module``, which may download the protocol
        schema, writes generated code to disk and imports it.
        """
        # Lua type names already emitted by create_class, so that each class
        # is only output once.
        self.known_types: set[str] = set()
        self.create_module()

    def create_module(self):
        protocol_path = Path("/tmp/dap-protocol.json")
        if not protocol_path.exists():
            logger.info("Downloading protocol")
            protocol_path.write_bytes(
                requests.get(
                    "https://raw.githubusercontent.com/microsoft/debug-adapter-protocol/gh-pages/debugAdapterProtocol.json"
                ).content
            )

        self.schema = json.loads(protocol_path.read_bytes().decode())
        _, output_path = tf.mkstemp()
        output_path = Path(output_path)

        d.generate(
            input_=protocol_path, output=output_path, use_schema_description=True
        )
        with open(output_path) as f:
            data = f.readlines()
        # Remove the update_forward_refs calls which cause errors
        while "update_forward_refs" in data[-1]:
            data.pop()

        with open(output_path, "w") as f:
            f.writelines(data)
        with open("model.py", "w") as f:
            f.writelines(data)

        # Import mymodule
        loader = importlib.machinery.SourceFileLoader("models", str(output_path))
        spec = importlib.util.spec_from_loader("models", loader)
        models_module = importlib.util.module_from_spec(spec)
        loader.exec_module(models_module)
        self.models = models_module

    def is_model(self, t) -> bool:
        return i.isclass(t) and issubclass(t, BaseModel)

    def extract_summary(
        self,
        model_cls: type[BaseModel],
    ) -> ModelSummary:
        fields = {}
        for field_name, field in model_cls.__fields__.items():
            if field and self.is_model(field.outer_type_):
                fields[field_name] = FieldSummary(
                    type=field.type_,
                    type_name=self.type_name(field.type_),
                )
        return ModelSummary(
            type=model_cls,
            type_name=self.type_name(model_cls),
            fields=fields,
            description=self.class_doc(model_cls),
        )

    def type_name(self, t) -> str:
        if self.is_model(t):
            return f"dapui.types.{t.__name__}"
        if t is int:
            return "integer"
        if t is float:
            return "number"
        if t is str:
            return "string"
        if t is bool:
            return "boolean"
        if i.isclass(t) and issubclass(t, Enum):
            return "|".join(f'"{member.value}"' for member in t.__members__.values())
        if get_origin(t) is Union:
            # Account for Optional with ? on field
            return " | ".join(
                self.type_name(a) for a in get_args(t) if a is not type(None)
            )
        if get_origin(t) is list:
            return f"{self.type_name(get_args(t)[0])}[]"
        if t is list:
            return f"any[]"
        if get_origin(t) is dict:
            return f"table<{self.type_name(get_args(t)[0])},{self.type_name(get_args(t)[1])}>"
        if t is dict:
            return f"table"
        if t is Any:
            return "any"
        if t is type(None):
            return "nil"
        raise Exception(f"Unknown type {t}")

    def safe_name(self, name: str) -> str:
        if name == "goto":
            return "goto_"
        return name

    def prepare_doc(self, doc: str, multiline: bool):
        lines = doc.strip().splitlines()
        if len(lines) == 1:
            return lines[0]
        if multiline:
            return f"{lines[0]}\n" + "\n".join(
                f"--- {line.strip()}" for line in lines[1:]
            )
        return " ".join(line.strip() for line in lines)

    def field_annotation(self, field: ModelField) -> str:
        description = field.field_info.description and self.prepare_doc(
            field.field_info.description, multiline=False
        )
        return f"---@field {field.name}{'?' if not field.required else ''} {self.type_name(field.outer_type_)} {description or ''}"

    def class_doc(self, model_cls: type[BaseModel]) -> str | None:
        if model_cls.__doc__:
            return model_cls.__doc__
        cls_schema = self.schema["definitions"].get(model_cls.__name__)
        if not cls_schema:
            return
        if sub_types := cls_schema.get("allOf"):
            for definition in sub_types:
                if sub_desc := definition.get("description"):
                    return sub_desc

    def create_class(self, t, class_name: str | None = None) -> list[str]:
        sub_classes = []
        lines = []
        t_name = class_name or self.type_name(t)
        if self.is_model(t) and t_name not in self.known_types:
            self.known_types.add(t_name)
            for field in t.__fields__.values():
                if field_sub_classes := self.create_class(field.outer_type_):
                    sub_classes += field_sub_classes
                    sub_classes.append("")
                lines.append(self.field_annotation(field))
            class_doc = self.class_doc(t)
            class_prefix = (
                [f"--- {self.prepare_doc(class_doc, multiline=True)}"]
                if class_doc
                else []
            )
            class_prefix.append(f"---@class {t_name}")
            lines = [
                *class_prefix,
                *lines,
            ]
        else:
            try:
                sub_classes = [
                    line
                    for sub_type in get_args(t)
                    for line in self.create_class(sub_type)
                ]
            except TypeError:
                ...
        return [
            *sub_classes,
            *lines,
        ]

    def create_request(
        self, request_cls: type[BaseModel], response_cls: type[BaseModel]
    ) -> list[str]:
        (command,) = list(request_cls.__fields__["command"].outer_type_.__members__)

        request_summary = self.extract_summary(request_cls)
        types = []
        signature = []
        if request_summary.description:
            signature.append(
                f"--- {self.prepare_doc(request_summary.description, multiline=True)}"
            )
        signature.append("---@async")
        if arg_summary := request_summary.fields.get("arguments"):
            types = self.create_class(arg_summary.type)
            signature += [
                f"---@param args {arg_summary.type_name}",
                f"function DAPUIRequestsClient.{self.safe_name(command)}(args) end",
            ]
        else:
            signature += [
                f"function DAPUIRequestsClient.{self.safe_name(command)}() end",
            ]

        response_summary = self.extract_summary(response_cls)
        if body_summary := response_summary.fields.get("body"):
            types.append("")
            types += self.create_class(body_summary.type, response_summary.type_name)
            signature.insert(-1, f"---@return {response_summary.type_name}")

        x = [
            *types,
            "",
            *signature,
            "",
            "",
        ]
        return x

    def create_event_listener(self, event_cls: type[BaseModel]) -> list[str]:
        (event,) = list(event_cls.__fields__["event"].outer_type_.__members__)
        event_summary = self.extract_summary(event_cls)
        types = []
        signature = []
        if event_summary.description:
            signature.append(
                f"--- {self.prepare_doc(event_summary.description, multiline=True)}"
            )
        if body_summary := event_summary.fields.get("body"):
            body_class_name = f"{event_summary.type_name}Args"
            types = self.create_class(body_summary.type, body_class_name)
            signature.append(f"---@param listener fun(args: {body_class_name})")
        else:
            signature.append(f"---@param listener fun()")

        signature.append(f"---@param opts? dapui.client.ListenerOpts")
        x = [
            *types,
            "",
            *signature,
            f"function DAPUIEventListenerClient.{self.safe_name(event)}(listener, opts) end",
            "",
            "",
        ]
        return x

    def create_request_listener(
        self, request_cls: type[BaseModel], response_cls: type[BaseModel]
    ) -> list[str]:
        (command,) = list(request_cls.__fields__["command"].outer_type_.__members__)

        request_summary = self.extract_summary(request_cls)
        response_summary = self.extract_summary(response_cls)

        listener_args = {}
        if args := request_summary.fields.get("arguments"):
            listener_args["request"] = args.type
        listener_args["error"] = dict | None
        if args := response_summary.fields.get("body"):
            listener_args["response"] = create_model(
                response_cls.__name__, __base__=args.type
            )

        args_summary = self.extract_summary(
            create_model(
                f"{command}RequestListenerArgs",
                **{name: (t, Field()) for name, t in listener_args.items()},
            )
        )
        types = self.create_class(args_summary.type)
        return [
            *types,
            "",
            f"---@param listener fun(args: {args_summary.type_name}): boolean | nil",
            f"---@param opts? dapui.client.ListenerOpts",
            f"function DAPUIEventListenerClient.{self.safe_name(command)}(listener, opts) end",
            "",
            "",
        ]

    # Lua preamble that generate_types places right after the "Generated on"
    # line: it declares the two client tables the generated functions are
    # attached to, and the listener options class referenced by the listeners.
    PREFIX = """
---@class dapui.DAPRequestsClient
local DAPUIRequestsClient = {}

---@class dapui.DAPEventListenerClient
local DAPUIEventListenerClient = {}

---@class dapui.client.ListenerOpts
---@field before boolean Run before event/request is processed by nvim-dap
"""

    def generate_types(self) -> str:
        output = f"--- Generated on {datetime.utcnow()}\n"
        output += self.PREFIX

        member_names = dict(i.getmembers(self.models))
        for _, member in i.getmembers(self.models, self.is_model):
            member.update_forward_refs(**member_names)
        for name, request in i.getmembers(
            self.models,
            lambda x: x is not self.models.Request
            and i.isclass(x)
            and issubclass(x, self.models.Request),
        ):
            try:
                output += "\n".join(
                    self.create_request(
                        request,
                        getattr(self.models, name.replace("Request", "Response")),
                    )
                )
                output += "\n".join(
                    self.create_request_listener(
                        request,
                        getattr(self.models, name.replace("Request", "Response")),
                    )
                )
            except Exception as e:
                logger.exception(f"Failed to create {name}: {e}")

        for name, event in i.getmembers(
            self.models,
            lambda x: x is not self.models.Event
            and i.isclass(x)
            and issubclass(x, self.models.Event),
        ):
            try:
                output += "\n".join(self.create_event_listener(event))
            except Exception as e:
                logger.exception(f"Failed to create {name}: {e}")
        output += "\nreturn { request = DAPUIRequestsClient, listen = DAPUIEventListenerClient }"
        return output


# Build the generator; construction loads (and, if not cached, downloads) the
# protocol schema and imports the generated models.
generator = DAPTypesGenerator()
logger.info("Generating types")

# Render the complete Lua annotation file as a single string.
types = generator.generate_types()

logger.info("Outputting types")
# The generated text is written to standard output.
print(types)
